from math import gamma
from turtle import pos
import jax 
import flax.linen as nn
import jax.numpy as jnp
from typing import NamedTuple, Optional,Any,Sequence

from jax import random
from src.utils import *
from src.models.rlt.layers import *
from flax.core.frozen_dict import freeze, unfreeze
from jax.flatten_util import ravel_pytree
import numpy as np

class UOROMemory(nn.Module):
    """Wrap a ``MemoryLayer`` so that its backward pass adds a UORO estimate.

    The forward computation is ``MemoryLayer(d_model, kernel_phi,
    update_rule)(inputs, memory, tick)``. Through a custom VJP, the parameter
    cotangent of that call is augmented with rank-one terms built from the
    factors carried in ``memory_grads`` (the contribution of earlier, truncated
    steps). When ``memory_grads`` is provided the factors are also advanced by
    one step and returned.
    """
    d_model:int
    kernel_phi:Any
    update_rule:str
    eps:float=1e-5  # Stabiliser added to numerator and denominator of the rescaling ratios.

    @nn.compact
    def __call__(self, inputs,memory,memory_grads=None,tick=None,ret_mem_grad_ax=0):
        """Run the memory layer and, optionally, advance the UORO factors.

        Args:
            inputs: Input sequence forwarded to ``MemoryLayer``.
            memory: ``(c, s)`` memory from the previous step.
            memory_grads: ``(c_tilde, s_tilde, c_tilde_theta, s_tilde_theta)``
                factors from the previous step, or ``None``. With ``None`` no
                factors are updated and the backward pass is the plain
                truncated (one-step) VJP.
            tick: Extra state forwarded unchanged to ``MemoryLayer``.
            ret_mem_grad_ax: Index into the leading axis of the layer's ``c``/``s``
                outputs selecting which step the new factors are computed for;
                should be either -1 or between 0 and T-1.

        Returns:
            ``(layer_output, new_memory_grads)``: the ``MemoryLayer`` output and
            the updated factors in the same layout as ``memory_grads``
            (``None`` when ``memory_grads`` is ``None``).
        """
        def f(mdl, inputs,memory,tick,memory_grads):
            # memory_grads is not used by the forward computation; it is an
            # argument only so that fwd can store it as a residual for bwd.
            return mdl(inputs,memory,tick)

        def fwd(mdl, inputs,memory,tick,memory_grads):
            output,vjp_fn=nn.vjp(f, mdl, inputs,memory,tick,memory_grads)
            return output,(vjp_fn,memory_grads)

        def bwd(residuals,y_t):
            vjp_fn,memory_grads=residuals
            # One-step cotangents: params first, then (inputs, memory, tick, memory_grads).
            params_t_1, *inputs_t = vjp_fn(y_t)
            if memory_grads is None:
                # No factors available: return the truncated gradient unchanged
                # (unpacking None below would otherwise raise a TypeError).
                return (params_t_1, *inputs_t)
            # Cotangents w.r.t. the previous memory (c_{t-1}, s_{t-1}).
            jvpct_ctminus1,jvpst_stminus1=inputs_t[1]
            ctilde_tminus1,stilde_tminus1,ctildetheta_tminus1,stildetheta_tminus1=memory_grads
            # A full contraction with each state tilde yields a scalar that scales
            # the matching parameter-shaped tilde.
            c_infl=jnp.tensordot(jvpct_ctminus1,ctilde_tminus1,axes=ctilde_tminus1.ndim)
            s_infl=jnp.tensordot(jvpst_stminus1,stilde_tminus1,axes=stilde_tminus1.ndim)
            params_t_2_c=jax.tree_map(lambda x: x*c_infl,ctildetheta_tminus1)
            params_t_2_s=jax.tree_map(lambda x: x*s_infl,stildetheta_tminus1)
            params_t=jax.tree_map(lambda x,y,z:x+y+z,params_t_1,params_t_2_c,params_t_2_s)
            return (params_t, *inputs_t)

        uoro_grad = nn.custom_vjp(
            f, forward_fn=fwd, backward_fn=bwd)
        csop_mdl=MemoryLayer(self.d_model,self.kernel_phi,self.update_rule)
        if memory_grads is not None:
            def rho(theta_tree, tilde):
                """Rescaling factor sqrt((||theta_tree|| + eps) / (||tilde|| + eps))."""
                theta_norm=jnp.linalg.norm(ravel_pytree(theta_tree)[0])
                return jnp.sqrt((theta_norm+self.eps)/(jnp.linalg.norm(tilde)+self.eps))

            ctilde_tminus1,stilde_tminus1,ctildetheta_tminus1,stildetheta_tminus1=memory_grads
            # Propagate the state tildes through the memory transition (forward-mode JVP).
            def jvp_fn(memory):
                c,s=csop_mdl(inputs,memory,tick)[0]
                return c[ret_mem_grad_ax],s[ret_mem_grad_ax]
            _,(ctilde_t_b,stilde_t_b)=jax.jvp(jvp_fn ,(memory,),((ctilde_tminus1,stilde_tminus1),))
            # Independent random sign vectors for c and s; drawing both from the
            # same key would correlate them.
            rng_c,rng_s=jax.random.split(self.make_rng('random'))
            ctilde_t_a=jax.random.choice(rng_c,jnp.array([+1.0,-1.0]),ctilde_t_b.shape)
            stilde_t_a=jax.random.choice(rng_s,jnp.array([+1.0,-1.0]),stilde_t_b.shape)
            ctildetheta_t_b=ctildetheta_tminus1
            stildetheta_t_b=stildetheta_tminus1
            # New parameter-shaped terms: VJP of the selected step's c / s with the sign vectors.
            _,cvjpfn=nn.vjp(lambda mdl: mdl(inputs,memory,tick)[0][0][ret_mem_grad_ax], csop_mdl,)
            ctildetheta_t_a,=cvjpfn(ctilde_t_a)
            _,svjpfn=nn.vjp(lambda mdl: mdl(inputs,memory,tick)[0][1][ret_mem_grad_ax], csop_mdl,)
            stildetheta_t_a,=svjpfn(stilde_t_a)
            # Variance-minimising rescaling factors, one per (state, parameter) pair.
            rhoc_0=rho(ctildetheta_t_a,ctilde_t_a)
            rhoc_1=rho(ctildetheta_t_b,ctilde_t_b)
            rhos_0=rho(stildetheta_t_a,stilde_t_a)
            rhos_1=rho(stildetheta_t_b,stilde_t_b)
            # tilde_t = rho_0*a + rho_1*b ; tildetheta_t = a_theta/rho_0 + b_theta/rho_1
            ctilde_t=rhoc_0*ctilde_t_a+rhoc_1*ctilde_t_b
            stilde_t=rhos_0*stilde_t_a+rhos_1*stilde_t_b
            # Presumably normalises both operands to the same top-level container
            # so the tree_map calls below see matching tree structures.
            ctildetheta_t_a=dict(freeze(ctildetheta_t_a))
            ctildetheta_t_b=dict(freeze(ctildetheta_t_b))
            stildetheta_t_a=dict(freeze(stildetheta_t_a))
            stildetheta_t_b=dict(freeze(stildetheta_t_b))
            irhoc_0=1/rhoc_0
            irhoc_1=1/rhoc_1
            ctildetheta_t=jax.tree_map(lambda x,y:x*irhoc_0+y*irhoc_1,ctildetheta_t_a,ctildetheta_t_b)
            irhos_0=1/rhos_0
            irhos_1=1/rhos_1
            stildetheta_t=jax.tree_map(lambda x,y:x*irhos_0+y*irhos_1,stildetheta_t_a,stildetheta_t_b)
            new_memory_grads=(ctilde_t,stilde_t,ctildetheta_t,stildetheta_t)
        else:
            new_memory_grads=None
        return uoro_grad(csop_mdl, inputs,memory,tick,memory_grads),new_memory_grads


class RecurrentLinearTransformerEncoderUORO(nn.Module):
    """Recurrent linear-transformer encoder layer whose memory uses ``UOROMemory``.

    The per-head memory update is ``UOROMemory`` vmapped over heads. It is
    followed by kernelised attention with per-head queries, an output
    projection with a residual LayerNorm, and a residual feed-forward block
    with LayerNorm.
    """
    d_model:int
    d_ffc:int
    n_heads:int
    kernel_dim:int #Output dim of the kernel function used
    kernel_phi:Any
    update_rule:str
    ret_mem_grad_ax:int

    @nn.compact
    def __call__(self,memory:dict, memory_grads,inputs,layer_id=None,pos_emb_type:str=None,use_dense=False):
        """Apply the layer to one truncation window.

        Args:
            memory: ``{'memory': Memory(c, s)}``, plus ``'tick'`` when a
                positional embedding is used.
            memory_grads: Per-head factors forwarded to ``UOROMemory``.
            inputs: Input sequence; its axis 0 length is used as the truncation T.
            layer_id: If not None, a layer embedding is added to the inputs.
            pos_emb_type: ``None``, ``'absolute'`` or ``'rotary'``.
            use_dense: If True, project the inputs with ``Dense(d_model)`` first.

        Returns:
            ``(out, new_memory, new_memory_grads)``. ``out`` is T x d_model.
            ``new_memory`` holds the per-step ``c``/``s`` with their first two
            axes swapped (time-major per the inline comments), plus ``'tick'``
            when ``pos_emb_type`` is set.

        Raises:
            ValueError: If ``pos_emb_type`` is not ``None``, ``'absolute'`` or ``'rotary'``.
        """
        if pos_emb_type not in (None, 'absolute', 'rotary'):
            # Fail early: otherwise an unknown type falls through and crashes
            # later with an UnboundLocalError on new_tick.
            raise ValueError(
                f"Unknown pos_emb_type {pos_emb_type!r}; expected None, 'absolute' or 'rotary'")
        truncation=inputs.shape[0]
        # inputs u_t^{i-1} shape T X d_model, c_tminus1: n_heads,d_model, kernel_dim
        #Memory: c: n_headsXd_modelXkernel_dim, s: n_headsXkernel_dim

        #Optional input projection, then layer embedding and position embedding
        if use_dense:
            inputs_enc=nn.Dense(self.d_model,name='emb_layer')(inputs)
        else:
            inputs_enc=inputs
        if layer_id is not None:
            inputs_embed=LayerEmbLayer(self.d_model,name='layer_emb')(inputs_enc,layer_id) #Layer Embedding
        else:
            inputs_embed=inputs_enc

        if pos_emb_type=='absolute': #Additive sinusoidal position embedding
            tick=memory['tick']
            inputs_embed,new_tick=AbsolutePosEmbLayer(self.d_model,name='pos_emb')(inputs_embed,tick)
        elif pos_emb_type=='rotary':
            tick=memory['tick']
            rotary_layer=nn.vmap(RotaryPosEmbLayer,in_axes=(0,None),out_axes=(0,None),variable_axes={'params': 0},
                                split_rngs={'params': True})(self.d_model,name='pos_emb')
        #Replicate the embedded input once per head
        inputs_repeat=jnp.repeat(jnp.expand_dims(inputs_embed,axis=0),repeats=self.n_heads,axis=0)
        memory_tuple=(memory['memory']['c'],memory['memory']['s'])
        if pos_emb_type=='rotary':
            csop_mh=nn.vmap(UOROMemory,in_axes=(0,0,0,None), out_axes=((0,None),0),
                                variable_axes={'params': 0},
                                split_rngs={'params': True,'random':True})(d_model=self.d_model,kernel_phi=self.kernel_phi,
                                                            update_rule=self.update_rule,name='csop')
            (state,_),new_memory_grads=csop_mh(inputs_repeat,memory_tuple,memory_grads,tick,ret_mem_grad_ax=self.ret_mem_grad_ax)
        else:
            csop_mh=nn.vmap(UOROMemory,in_axes=(0,0,0), out_axes=((0,),0),
                                variable_axes={'params': 0},
                                split_rngs={'params': True,'random':True})(d_model=self.d_model,kernel_phi=self.kernel_phi,
                                                            update_rule=self.update_rule,name='csop')
            (state,),new_memory_grads=csop_mh(inputs_repeat,memory_tuple,memory_grads,ret_mem_grad_ax=self.ret_mem_grad_ax)

        # n_headsXTXd_model_kernel_dim, s: n_headsXTXkernel_dim
        query_layer_mh=nn.vmap(nn.Sequential,in_axes=0, out_axes=0,
                                variable_axes={'params': 0},
                                split_rngs={'params': True})([nn.Dense(self.d_model),self.kernel_phi],name='query')

        Q_t=query_layer_mh(inputs_repeat) #n_headsXTXkernel_dim

        #Rotate query if rotary embedding
        if pos_emb_type=='rotary':
            Q_t,new_tick=rotary_layer(Q_t,tick)

        if self.update_rule=='delta':
            Q_t=Q_t/((jnp.expand_dims(jnp.linalg.norm(Q_t,axis=-1),-1))+ 1e-6)

        c,s=state

        #Apply attention per head
        attn_out=jax.vmap(attention_func)(c,s,Q_t)

        #Combine output of n heads
        attn_out=jnp.transpose(attn_out,(1,0,2)).reshape(truncation,-1) #TXd_model*n_heads
        attn_out=nn.Dense(self.d_model)(attn_out) #TXd_model
        attn_out=nn.LayerNorm()(attn_out+inputs_embed)
        #Add only previous output because this is a decoder
        ffc=nn.Sequential([nn.Dense(self.d_ffc),jax.nn.relu,nn.Dense(self.d_model)])
        out=nn.LayerNorm()(ffc(attn_out)+attn_out)

        memory=Memory(c=jnp.transpose(c,(1,0,2,3)),s=jnp.transpose(s,(1,0,2))) #TXc, TXs

        if pos_emb_type is not None:
            new_memory={
                'memory':memory,
                'tick':new_tick,
            }
        else:
            new_memory={
                'memory':memory,
            }
        return out,new_memory,new_memory_grads


    @staticmethod
    def initialize_memory(n_heads,d_model,kernel_dim,pos_emb_type,update_rule,kernel_phi):
        """Return zero initial memory and zero UORO factors for one layer.

        Returns:
            ``(memory_dict, memory_grads)``. ``memory_dict`` is ``{'memory': Memory(c, s)}``
            plus ``'tick'`` for ``'rotary'``/``'absolute'`` embeddings.
            ``memory_grads`` is ``(c_tilde, s_tilde, ctilde_theta, stilde_theta)``;
            the theta factors mirror ``MemoryLayer``'s variables with a leading
            ``n_heads`` axis.
        """
        c_tminus1=jnp.zeros((n_heads,d_model,kernel_dim))
        # s must have the same shape as s_tilde below (they are paired as
        # primal/tangent in jax.jvp inside UOROMemory) and as the per-head s used
        # to initialise MemoryLayer, i.e. kernel_dim rather than d_model.
        s_tminus1=jnp.zeros((n_heads,kernel_dim))
        memory=Memory(c=c_tminus1,s=s_tminus1)
        #Create Memory grads
        c_tilde=jnp.zeros((n_heads,d_model,kernel_dim))
        s_tilde=jnp.zeros((n_heads,kernel_dim,))
        params=MemoryLayer(d_model,kernel_phi,update_rule).init(random.PRNGKey(0), jnp.zeros((1,d_model)),
                                               (jnp.zeros((d_model,kernel_dim)),jnp.zeros((kernel_dim))))
        ctilde_theta=dict(jax.tree_map(lambda x: jnp.zeros((n_heads,*x.shape)),params))
        stilde_theta=dict(jax.tree_map(lambda x: jnp.zeros((n_heads,*x.shape)),params))
        memory_grads=(c_tilde,s_tilde,ctilde_theta,stilde_theta)
        if pos_emb_type=='rotary':
            tick=RotaryPosEmbLayer.initialize_rotation_matrix(d_model)
            return {'memory':memory,'tick':tick},memory_grads
        elif pos_emb_type=='absolute':
            tick=jnp.array(1,dtype=np.uint)
            return {'memory':memory,'tick':tick},memory_grads
        else:
            return {'memory':memory},memory_grads


    @staticmethod
    def create_memory(c_tminus1,s_tminus1,tick):
        """Pack ``c``, ``s`` and ``tick`` into the memory-dict layout read by ``__call__``."""
        return {'memory':Memory(c=c_tminus1,s=s_tminus1),'tick':tick}



class RecurrentLinearTransformerUORO(nn.Module):
    """Stack of ``RecurrentLinearTransformerEncoderUORO`` layers.

    Only the first layer receives the positional embedding type and the input
    ``Dense`` projection. When ``use_layer_emb`` is set, layer ``i`` (0-based)
    is given ``layer_id = i + 1``.
    """
    n_layers:int
    d_model:int
    d_ffc:int
    n_heads:int
    kernel_dim:int
    kernel_phi:Any=eluplus1
    pos_emb_type:str='rotary'
    use_layer_emb:bool=True
    update_rule:str='gated'
    ret_mem_grad_ax:int=0

    @nn.compact
    def __call__(self,inputs,last_memory,last_memory_grads):
        """Feed ``inputs`` through each encoder layer in turn.

        Args:
            inputs: Input sequence for the first layer.
            last_memory: Per-layer memory dicts; per the layer code, c is
                n_heads x d_model x kernel_dim per layer.
            last_memory_grads: Per-layer UORO factor tuples.

        The number of layers actually run equals the length of the shorter of
        ``last_memory`` and ``last_memory_grads``; ``n_layers`` is not read here.

        Returns:
            ``(output, new_memory, new_memory_grads)``, where the two lists hold
            one entry per layer that was run.
        """
        hidden=inputs
        memories_out=[]
        grads_out=[]
        for idx,(layer_memory,layer_grads) in enumerate(zip(last_memory,last_memory_grads)):
            is_first=idx==0
            encoder=RecurrentLinearTransformerEncoderUORO(
                d_model=self.d_model,
                d_ffc=self.d_ffc,
                n_heads=self.n_heads,
                kernel_dim=self.kernel_dim,
                kernel_phi=self.kernel_phi,
                update_rule=self.update_rule,
                ret_mem_grad_ax=self.ret_mem_grad_ax,
                name=f'layer{idx}',
            )
            layer_id=idx+1 if self.use_layer_emb else None
            hidden,mem_new,grads_new=encoder(
                layer_memory,
                layer_grads,
                hidden,
                layer_id,
                pos_emb_type=self.pos_emb_type if is_first else None,
                use_dense=is_first,
            )
            memories_out.append(mem_new)
            grads_out.append(grads_new)
        return hidden,memories_out,grads_out

    @staticmethod
    def initialize_memory(n_layers,n_heads,d_model,kernel_dim,pos_emb_type,update_rule,kernel_phi):
        """Build initial memories and UORO factors for ``n_layers`` layers.

        Each layer is initialised with the same arguments via
        ``RecurrentLinearTransformerEncoderUORO.initialize_memory``.

        Returns:
            ``(memory_list, grads_list)``, two lists with one entry per layer.
        """
        per_layer=[
            RecurrentLinearTransformerEncoderUORO.initialize_memory(
                n_heads,d_model,kernel_dim,pos_emb_type,update_rule,kernel_phi)
            for _ in range(n_layers)
        ]
        return [mem for mem,_ in per_layer],[grads for _,grads in per_layer]